from __future__ import absolute_import, division, print_function

import numpy
import numpy as np
from numba import jit, njit


@njit
def cal_distance(mat, Xbin, Ybin, Zbin):
    """Fill ``mat`` with Euclidean distances between every pair of grid cells.

    Cells of the (Ybin, Xbin, Zbin) grid are numbered in C order, so cell
    ``c`` sits at ``y = c // (Xbin * Zbin)``, ``x = (c // Zbin) % Xbin`` and
    ``z = c % Zbin``.  ``mat[r, c]`` receives the distance (in bin units)
    between cells ``r`` and ``c``.  ``mat`` is written in place and returned.
    """
    # With any empty (or negative) dimension there are no cells to compare.
    if Xbin <= 0 or Ybin <= 0 or Zbin <= 0:
        return mat

    plane = Xbin * Zbin  # cells per y-slice
    total = Ybin * plane
    for col in range(total):
        cy = col // plane
        cx = (col // Zbin) % Xbin
        cz = col % Zbin
        for row in range(total):
            dx = cx - (row // Zbin) % Xbin
            dy = cy - row // plane
            dz = cz - row % Zbin
            mat[row, col] = numpy.sqrt(dx ** 2 + dy ** 2 + dz ** 2)
    return mat


@jit(nopython=True)
def fill_alignment_matrix(A, B, sub_matrix, gap_value):
    """Fill the global-alignment (Needleman-Wunsch) score matrix for ``A`` and ``B``.

    Parameters
    ----------
    A, B : 1-D arrays of integer bin ids
        Each element is used as an index into ``sub_matrix``.
    sub_matrix : 2-D array
        ``sub_matrix[a, b]`` is the score for aligning element ``a`` with ``b``.
    gap_value : float
        Score added for every gap (an element aligned against nothing).

    Returns
    -------
    F : ndarray of float64, shape (len(A) + 1, len(B) + 1)
        ``F[i, j]`` is the best score for aligning ``A[:i]`` with ``B[:j]``
        (shifted by the boundary initialisation described below).
    """
    n, m = len(A), len(B)
    F = numpy.zeros((n + 1, m + 1))

    # Boundary row/column: cumulative gap penalties.
    # NOTE(review): these are gap_value * (k + 1) rather than the textbook
    # gap_value * k, which shifts every entry of F (and hence max(F)) by
    # gap_value. It makes no difference when gap_value is 0 (ScanMatch's
    # default GapValue). Confirm against the reference ScanMatch before changing.
    for i in range(n + 1):
        F[i, 0] = gap_value * (i + 1)
    for j in range(m + 1):
        F[0, j] = gap_value * (j + 1)

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            match = F[i - 1, j - 1] + sub_matrix[A[i - 1], B[j - 1]]
            delete = F[i - 1, j] + gap_value  # A[i-1] against a gap
            insert = F[i, j - 1] + gap_value  # B[j-1] against a gap
            F[i, j] = max(match, delete, insert)
    return F


class ScanMatch(object):
    """
    ScanMatch Object: compares scanpaths by sequence alignment on a 3-D grid.

    Fixation coordinates inside an ``Xres x Yres x Zres`` space are quantised
    into ``Xbin * Ybin * Zbin`` spatial bins. They are optionally repeated
    according to duration (``TempBin``), and two such bin sequences are then
    globally aligned (Needleman-Wunsch style) using a distance-based
    substitution matrix.

    Keyword parameters (all optional):
    - ``Xres``, ``Yres``, ``Zres``: integer extents of the space.
    - ``Xbin``, ``Ybin``, ``Zbin``: integer bin counts.
    - ``Threshold``: see :meth:`CreateSubMatrix`.
    - ``GapValue``: score per alignment gap.
    - ``TempBin``: duration bin size, where 0 ignores durations.
    - ``Offset``: an (x, y, z) offset subtracted from fixation coordinates.

    Any other name raises ``ValueError``.
    """

    # Keyword arguments accepted by __init__; each overrides the attribute of
    # the same name.
    _PARAMETERS = (
        "Xres",
        "Yres",
        "Zres",
        "Xbin",
        "Ybin",
        "Zbin",
        "Threshold",
        "GapValue",
        "TempBin",
        "Offset",
    )

    def __init__(self, **kw):
        self.Xres = 1024
        self.Yres = 768
        self.Zres = 512
        self.Xbin = 8
        self.Ybin = 6
        self.Zbin = 6
        self.Threshold = 3.5
        self.GapValue = 0.0
        self.TempBin = 0.0
        self.Offset = (0, 0, 0)

        for key, value in kw.items():
            if key not in self._PARAMETERS:
                raise ValueError("Unknown parameter: %s." % key)
            setattr(self, key, value)

        # Element-wise int() truncation. otypes lets it also accept zero-row
        # input, for which numpy.vectorize cannot infer an output type.
        self.int_vectorize = numpy.vectorize(int, otypes=[int])

        self.CreateSubMatrix()
        self.GridMask()

    def CreateSubMatrix(self, Threshold=None):
        """Build ``self.SubMatrix``, the bin-to-bin substitution scores.

        ``SubMatrix[a, b]`` equals (up to rounding) ``Threshold - dist(a, b)``.
        Here ``dist`` is the Euclidean distance between bins ``a`` and ``b``
        in bin units. Identical bins therefore score ``Threshold``, and bins
        farther apart than ``Threshold`` score negatively.

        Parameters
        ----------
        Threshold : float, optional
            If given, replaces ``self.Threshold`` before the matrix is built.
        """
        if Threshold is not None:
            self.Threshold = Threshold
        n_bins = self.Xbin * self.Ybin * self.Zbin
        dist = cal_distance(
            numpy.zeros((n_bins, n_bins)), self.Xbin, self.Ybin, self.Zbin
        )
        max_dist = numpy.max(dist)

        # Every entry is <= max_dist, so this is Threshold - dist. The
        # original expression is kept so floating-point results are unchanged.
        self.SubMatrix = numpy.abs(dist - max_dist) - (max_dist - self.Threshold)

    def GridMask(self):
        """Build ``self.mask``, a lookup table from pixel (y, x, z) to bin id.

        ``self.mask[y, x, z]`` is the id of the bin containing that pixel.
        Pixel ``p`` on an axis of resolution ``res`` split into ``bins`` falls
        in bin ``p * bins // res``. Bin ids follow the C-order flattening of a
        (Ybin, Xbin, Zbin) grid, the same cell order ``cal_distance`` uses.

        NOTE(review): the mask has Yres * Xres * Zres entries. At the default
        resolution that is 402,653,184 entries, about 3.2 GB with 64-bit ints.
        """
        bin_ids = numpy.arange(self.Xbin * self.Ybin * self.Zbin).reshape(
            (self.Ybin, self.Xbin, self.Zbin)
        )
        # Exact integer floor(p * bins / res) for every pixel p. A float
        # arange step can round a pixel lying exactly on a bin boundary into
        # the lower bin, and can produce one element too many.
        xi = numpy.arange(self.Xres) * self.Xbin // self.Xres
        yi = numpy.arange(self.Yres) * self.Ybin // self.Yres
        zi = numpy.arange(self.Zres) * self.Zbin // self.Zres
        self.mask = bin_ids[numpy.ix_(yi, xi, zi)]

    def fixationToSequence(self, data):
        """Convert fixations into a sequence of spatial bin ids.

        Parameters
        ----------
        data : array_like, shape (n_fixations, >= 3)
            Columns 0-2 are x, y, z coordinates. Column 3 is a duration and is
            read only when ``TempBin`` is non-zero.

        Returns
        -------
        ndarray of int
            One bin id per fixation. When ``TempBin`` is non-zero, each bin id
            is instead repeated ``round(duration / TempBin)`` times.
        """
        d = numpy.array(data, dtype=float)  # fresh float copy; input untouched
        d[:, :3] -= self.Offset
        # Clamp into range so the mask lookup cannot go out of bounds. This
        # also clamps negative values in every other column, including the
        # duration, to 0.
        d[d < 0] = 0
        d[d[:, 0] >= self.Xres, 0] = self.Xres - 1
        d[d[:, 1] >= self.Yres, 1] = self.Yres - 1
        d[d[:, 2] >= self.Zres, 2] = self.Zres - 1
        d = self.int_vectorize(d)  # truncate to integer pixel / time values
        seq_num = self.mask[d[:, 1], d[:, 0], d[:, 2]]

        if self.TempBin != 0:
            # Durations measured in units of TempBin. Each bin id is repeated
            # once per unit, e.g. ids [1, 2] with counts [2, 1] -> [1, 1, 2].
            fix_time = numpy.round(d[:, 3] / float(self.TempBin))
            # Negative counts (reachable only with a negative TempBin)
            # contribute nothing instead of raising.
            counts = numpy.maximum(fix_time, 0).astype(int)
            seq_num = numpy.repeat(seq_num, counts)
        return seq_num

    def match(self, A, B):
        """Globally align two bin sequences and return their similarity.

        Parameters
        ----------
        A, B : 1-D arrays of integer bin ids
            Typically produced by :meth:`fixationToSequence`.

        Returns
        -------
        matchScore : float
            ``max(F) / (max(SubMatrix) * max(len(A), len(B)))``. The divisor is
            0 when both sequences are empty, so the score is then not finite.
        align : ndarray, shape (k, 2)
            Aligned (A, B) element pairs in forward order; -1 marks a gap.
        F : ndarray, shape (len(B) + 1, len(A) + 1)
            The alignment score matrix, transposed.
        """
        n = len(A)
        m = len(B)
        sub = self.SubMatrix
        gap = self.GapValue

        F = fill_alignment_matrix(A, B, sub, gap)

        # Trace back from F[n, m]. Slots left at -1 mark a gap in that sequence.
        AlignmentA = numpy.full(n + m, -1.0)
        AlignmentB = numpy.full(n + m, -1.0)
        i, j, step = n, m, 0

        while i > 0 and j > 0:
            score = F[i, j]
            # Prefer diagonal (match), then up (A element against a gap), then
            # left (B element against a gap). The equality tests rely on these
            # sums reproducing fill_alignment_matrix's float64 arithmetic.
            if score == F[i - 1, j - 1] + sub[A[i - 1], B[j - 1]]:
                AlignmentA[step] = A[i - 1]
                AlignmentB[step] = B[j - 1]
                i -= 1
                j -= 1
            elif score == F[i - 1, j] + gap:
                AlignmentA[step] = A[i - 1]
                i -= 1
            else:
                AlignmentB[step] = B[j - 1]
                j -= 1
            step += 1

        # Whatever remains of either sequence is aligned against gaps.
        while i > 0:
            AlignmentA[step] = A[i - 1]
            i -= 1
            step += 1

        while j > 0:
            AlignmentB[step] = B[j - 1]
            j -= 1
            step += 1

        F = F.transpose()

        scale = numpy.max(sub) * max(m, n)
        matchScore = numpy.max(F) / scale

        # The traceback filled the alignment back to front; reverse the used
        # part so pairs come out in forward order.
        align = numpy.vstack(
            [AlignmentA[step - 1 :: -1], AlignmentB[step - 1 :: -1]]
        ).transpose()

        return matchScore, align, F
